"""
    Software ISP helper: runs a raw image through a configurable ISP pipeline.

    author: wxz
    date: 2021-12-6
    github: https://github.com/xinzwang

    ** Software ISP Hyperparameters (flat vector of 20 values, in this order)
        -- Pipeline
            black level correction   params: 2  [black_level, white_level]                                    (indices 0-1)
            demosaic                 params: 0
            white balance            params: 3  [gain_1, gain_2, gain_3]                                       (indices 2-4)
            color correction         params: 9  [c_00, c_01, c_02, c_10, c_11, c_12, c_20, c_21, c_22]         (indices 5-13, reshaped to 3x3)
            gamma correction         params: 2  [mul, gamma]                                                   (indices 14-15)
            tone curve               params: 4  [mul, kernel_w, kernel_h, sigma]                               (indices 16-19)
"""

import time

from blocks import *
from isp.blocks.utils import imshow


class IspHelper(object):
    """Run a software ISP pipeline over a raw image.

    Stages, in order: black level correction, demosaicing, white balance,
    color correction, gamma correction and tone curve.

    Hyperparameters are supplied through :meth:`set_params` in one of two forms:

    * a flat 20-element vector laid out as described in the module docstring.
      It must support slicing and ``.reshape`` (e.g. a numpy array or torch
      tensor).
    * a mapping (see :meth:`_stage_params` for the keys that are read).
    """

    def __init__(self):
        self.params = None  # pipeline hyperparameters, set via set_params()
        self.raw = None     # raw image cached by load_raw()
        self.img = None     # output of the most recent forward() call

    def set_params(self, params):
        """Store the hyperparameters used by forward() and forward_train()."""
        self.params = params

    def _stage_params(self):
        """Split ``self.params`` into the per-stage pipeline arguments.

        Returns:
            Tuple ``(black_level, white_level, channel_gain, color_matrix,
            gamma_multiplier, gamma, tone_strength_multiplier,
            tone_kernel_size, tone_sigma)``.

        For a mapping, black/white level are read from keys ``0`` and ``1``,
        and the remaining values from the nested ``'white_balance'``,
        ``'color_correction'``, ``'gamma_correction'`` and ``'tone_curve'``
        entries. Any other object is treated as the flat 20-value vector.
        """
        from collections.abc import Mapping

        p = self.params
        if isinstance(p, Mapping):
            # Same lookups forward() has always performed for dict params.
            return (
                p[0],
                p[1],
                p['white_balance']['channel_gain'],
                p['color_correction']['color_matrix'],
                p['gamma_correction']['multiplier'],
                p['gamma_correction']['gamma'],
                p['tone_curve']['strength_multiplier'],
                p['tone_curve']['kernel_size'],
                p['tone_curve']['sigma'],
            )
        # Flat layout; slices (not copies) are kept so that array/tensor
        # semantics of the caller's params are preserved.
        return (
            p[0],                     # black_level
            p[1],                     # white_level
            p[2:5],                   # white-balance channel gains
            p[5:14].reshape(3, 3),    # color correction matrix
            p[14],                    # gamma multiplier
            p[15],                    # gamma
            p[16],                    # tone-curve strength multiplier
            p[17:19],                 # tone-curve kernel size (w, h)
            p[19],                    # tone-curve sigma
        )

    @staticmethod
    def _run_stage(label, title, cmap, display, func, *args):
        """Run one pipeline stage, print its runtime and optionally display the result."""
        print(label)
        start = time.perf_counter()
        img = func(*args)
        end = time.perf_counter()
        print('    runtime:', end - start)
        if display:
            imshow(img, title, cmap=cmap)
        return img

    def forward_train(self, raw=None):
        """Run the full pipeline without logging, timing or display.

        Args:
            raw: raw image to process; falls back to the image cached by
                load_raw() when None.

        Returns:
            The processed image (``self.img`` is not updated).

        Raises:
            ValueError: if no params were set or no raw image is available.
        """
        if self.params is None:
            raise ValueError('[ERROR]: please set params to IspHelper ')
        if raw is None:
            raw = self.raw
        if raw is None:
            raise ValueError('[ERROR]: please set a Raw Image to IspHelper ')
        (black_level, white_level, channel_gain, color_matrix,
         multiplier, gamma, strength_multiplier, kernel_size,
         sigma) = self._stage_params()
        img = black_level_correction(raw, black_level, white_level)
        img = demosaik(img)
        img = channel_gain_white_balance(img, channel_gain)
        img = color_correction(img, color_matrix)
        img = gamma_correction(img, multiplier, gamma)
        img = tone_curve(img, strength_multiplier, kernel_size, sigma)
        return img

    def forward(self, raw=None, display=False):
        """Run the ISP pipeline, printing each stage's runtime.

        Args:
            raw: raw image to process; falls back to the image cached by
                load_raw() when None.
            display: when True, show the intermediate image after every stage.

        Returns:
            The processed image (also stored in ``self.img``), or None after
            printing an error when params or the raw image are missing.
        """
        # 0. Parameter validation
        if self.params is None:
            print('[ERROR]: please set params to IspHelper ')
            return None
        if raw is None:
            raw = self.raw
        if raw is None:
            # Previously this only printed and then crashed in stage 1.
            print('[ERROR]: please set a Raw Image to IspHelper ')
            return None

        (black_level, white_level, channel_gain, color_matrix,
         multiplier, gamma, strength_multiplier, kernel_size,
         sigma) = self._stage_params()

        # 1. Black level correction
        img = self._run_stage('[Black level correction]', 'black level correction', 'gray', display,
                              black_level_correction, raw, black_level, white_level)
        # 2. Demosaicing
        img = self._run_stage('[Demosaik]', 'Demosaik', 'rgb', display,
                              demosaik, img)
        # 3. White balance
        img = self._run_stage('[White balance]', 'White balance', 'rgb', display,
                              channel_gain_white_balance, img, channel_gain)
        # 4. Color (space) correction
        img = self._run_stage('[Color correction]', 'Color correction', 'rgb', display,
                              color_correction, img, color_matrix)
        # 5. Gamma correction
        img = self._run_stage('[Gamma correction]', 'Gamma correction', 'rgb', display,
                              gamma_correction, img, multiplier, gamma)
        # 6. Tone curve correction
        img = self._run_stage('[Tone curve]', 'Tone curve', 'rgb', display,
                              tone_curve, img, strength_multiplier, kernel_size, sigma)

        self.img = img
        return img

    def load_raw(self, path):
        """Load data via load_from_path(), cache its first element as the raw image and return it all."""
        data = load_from_path(path)
        self.raw = data[0]
        return data

    def imshow(self):
        """Display the output of the most recent forward() call."""
        if self.img is None:
            print('[ERROR]: please run forward() before imshow()')
            return
        # Bare name resolves to the module-level imshow helper, not this method.
        imshow(self.img, 'After ISP', cmap='rgb')
